{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Construction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this section, we construct two classes to implement a basic feed-forward neural network. For simplicity, both are limited to one hidden layer, though the number of neurons in the input, hidden, and output layers is flexible. The two differ in how they combine results across observations: the first loops through observations and adds up the individual gradients, while the second calculates the entire gradient across observations in one fell swoop.\n",
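    "\n",
    "As a minimal illustration of the difference, the sketch below uses an ordinary linear regression with the RSS loss instead of a network (an assumption made only to keep the example short). Summing the per-observation gradients gives the same result as computing the gradient with matrix operations; the toy data and variable names are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy data, assumed for illustration: N = 5 observations, D = 3 features\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.normal(size=(5, 3))\n",
    "y = rng.normal(size=5)\n",
    "beta = rng.normal(size=3)\n",
    "\n",
    "# Loop approach: add up the RSS gradient of each observation\n",
    "grad_loop = np.zeros(3)\n",
    "for n in range(len(y)):\n",
    "    residual_n = y[n] - X[n] @ beta\n",
    "    grad_loop += -2 * residual_n * X[n]\n",
    "\n",
    "# Matrix approach: the same gradient for all observations at once\n",
    "grad_matrix = -2 * X.T @ (y - X @ beta)\n",
    "\n",
    "print(np.allclose(grad_loop, grad_matrix))  # True\n",
    "```\n",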
    "\n",
    "Let's start by importing `numpy`, some visualization packages, and two datasets: the {doc}`Boston </content/appendix/data>` housing and {doc}`breast cancer </content/appendix/data>` datasets from `scikit-learn`. We will use the former for regression and the latter for classification. We also split each dataset into a train and test set. This is done in the hidden code cell below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 591,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "## Import numpy and visualization packages\n",
    "import numpy as np \n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn import datasets\n",
    "\n",
    "## Import Boston and standardize (load_boston was removed in scikit-learn 1.2, so this requires an earlier version)\n",
    "np.random.seed(123)\n",
    "boston = datasets.load_boston()\n",
    "X_boston = boston['data']\n",
    "X_boston = (X_boston - X_boston.mean(0))/(X_boston.std(0))\n",
    "y_boston = boston['target']\n",
    "\n",
    "## Train-test split\n",
    "np.random.seed(123)\n",
    "test_frac = 0.25\n",
    "test_size = int(len(y_boston)*test_frac)\n",
    "test_idxs = np.random.choice(np.arange(len(y_boston)), test_size, replace = False)\n",
    "X_boston_train = np.delete(X_boston, test_idxs, 0)\n",
    "y_boston_train = np.delete(y_boston, test_idxs, 0)\n",
    "X_boston_test = X_boston[test_idxs]\n",
    "y_boston_test = y_boston[test_idxs]\n",
    "\n",
    "## Import cancer and standardize\n",
    "np.random.seed(123)\n",
    "cancer = datasets.load_breast_cancer()\n",
    "X_cancer = cancer['data']\n",
    "X_cancer = (X_cancer - X_cancer.mean(0))/(X_cancer.std(0))\n",
    "y_cancer = 1*(cancer['target'] == 1)\n",
    "\n",
    "## Train-test split\n",
    "np.random.seed(123)\n",
    "test_frac = 0.25\n",
    "test_size = int(len(y_cancer)*test_frac)\n",
    "test_idxs = np.random.choice(np.arange(len(y_cancer)), test_size, replace = False)\n",
    "X_cancer_train = np.delete(X_cancer, test_idxs, 0)\n",
    "y_cancer_train = np.delete(y_cancer, test_idxs, 0)\n",
    "X_cancer_test = X_cancer[test_idxs]\n",
    "y_cancer_test = y_cancer[test_idxs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before constructing classes for our network, let's build our activation functions. Below we implement the ReLU function, sigmoid function, and the linear function (which simply returns its input). Let's also combine these functions into a dictionary so we can identify them with a string argument. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 544,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Activation Functions \n",
    "def ReLU(h):\n",
    "    return np.maximum(h, 0)\n",
    "\n",
    "def sigmoid(h):\n",
    "    return 1/(1 + np.exp(-h))\n",
    "    \n",
    "def linear(h):\n",
    "    return h\n",
    "\n",
    "activation_function_dict = {'ReLU':ReLU, 'sigmoid':sigmoid, 'linear':linear}"
   ]
  },
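  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both fitting routines below also need the derivative of each activation function during backpropagation. They compute these derivatives inline, but as a minimal sketch they could be collected in a companion dictionary. The names `ReLU_prime`, `sigmoid_prime`, `linear_prime`, and `activation_derivative_dict` are hypothetical and are not used elsewhere in this section. Note that the derivative of ReLU is the indicator of `h > 0`, not ReLU itself.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical helpers: the classes below compute these derivatives inline\n",
    "def ReLU_prime(h):\n",
    "    # 1 where h > 0 and 0 elsewhere (the value at exactly 0 is a convention)\n",
    "    return 1.0 * (h > 0)\n",
    "\n",
    "def sigmoid_prime(h):\n",
    "    s = 1 / (1 + np.exp(-h))\n",
    "    return s * (1 - s)\n",
    "\n",
    "def linear_prime(h):\n",
    "    return np.ones_like(h)\n",
    "\n",
    "activation_derivative_dict = {'ReLU': ReLU_prime, 'sigmoid': sigmoid_prime, 'linear': linear_prime}\n",
    "\n",
    "h = np.array([-2.0, 0.0, 3.0])\n",
    "print(activation_derivative_dict['ReLU'](h))       # [0. 0. 1.]\n",
    "print(activation_derivative_dict['sigmoid'](0.0))  # 0.25\n",
    "```"
   ]
  },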
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. The Loop Approach"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we construct a class for fitting feed-forward networks by looping through observations. This class conducts gradient descent by calculating the gradients based on one observation at a time, looping through all observations, and summing the gradients before adjusting the weights.\n",
    "\n",
    "Once instantiated, we fit a network with the `fit()` method. This method requires training data, the number of nodes in the hidden layer, an activation function for each of the two layers' outputs, a loss function, and some parameters for gradient descent. After storing those values, the method randomly initializes the network's weights and biases: `W1`, `c1`, `W2`, and `c2`. It then passes the data through this network to initialize the output values: `h1`, `z1`, `h2`, and `yhat` (equivalent to `z2`).\n",
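    "\n",
    "The forward pass stores observations as columns, so each layer's output has one column per observation and the bias vectors are broadcast across those columns. A minimal sketch with made-up sizes (not the Boston data) shows the shapes involved; the sizes and the use of `np.random.default_rng` are assumptions for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Assumed sizes: N observations, D_X inputs, D_h hidden units, D_y outputs\n",
    "N, D_X, D_h, D_y = 4, 3, 8, 1\n",
    "rng = np.random.default_rng(123)\n",
    "X = rng.normal(size=(N, D_X))\n",
    "\n",
    "W1 = rng.normal(size=(D_h, D_X)) / 5\n",
    "c1 = rng.normal(size=(D_h, 1)) / 5\n",
    "W2 = rng.normal(size=(D_y, D_h)) / 5\n",
    "c2 = rng.normal(size=(D_y, 1)) / 5\n",
    "\n",
    "h1 = W1 @ X.T + c1      # (D_h, N): c1 is broadcast across the N columns\n",
    "z1 = np.maximum(h1, 0)  # ReLU after the first layer\n",
    "h2 = W2 @ z1 + c2       # (D_y, N)\n",
    "yhat = h2               # linear activation after the second layer\n",
    "\n",
    "print(h1.shape, z1.shape, h2.shape, yhat.shape)  # (8, 4) (8, 4) (1, 4) (1, 4)\n",
    "```\n",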
    "\n",
    "We then begin conducting gradient descent. Within each iteration of the gradient descent process, we also iterate through the observations. For each observation, we calculate the derivative of the loss for that observation with respect to the network's weights. We then sum these individual derivatives and adjust the weights accordingly, as is typical in gradient descent. The derivatives we calculate are covered in the {doc}`concept section </content/c7/concept>`. \n",
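    "\n",
    "Each per-observation derivative is a Jacobian. For instance, slice `i` of `dh1_dW1` is the derivative of the hidden pre-activation `h1[i]` with respect to the whole matrix `W1`, and only row `i` of that slice is nonzero (it equals the observation's input). A minimal sketch with made-up numbers shows that contracting the upstream gradient with this tensor reduces to an outer product; the sizes and variable names are assumptions for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Assumed sizes for a single observation\n",
    "D_h, D_X = 4, 3\n",
    "rng = np.random.default_rng(0)\n",
    "x_n = rng.normal(size=D_X)     # one observation's inputs\n",
    "dL_dh1 = rng.normal(size=D_h)  # upstream gradient for that observation\n",
    "\n",
    "# dh1_dW1[i, j, k] = d h1_i / d W1_jk, which is x_n[k] when i == j and 0 otherwise\n",
    "dh1_dW1 = np.zeros((D_h, D_h, D_X))\n",
    "for i in range(D_h):\n",
    "    dh1_dW1[i, i] = x_n\n",
    "\n",
    "# Chain rule: dL/dW1_jk = sum over i of dL/dh1_i * dh1_i/dW1_jk\n",
    "dL_dW1 = np.einsum('i,ijk->jk', dL_dh1, dh1_dW1)\n",
    "\n",
    "# The tensor collapses to an outer product of the upstream gradient and the input\n",
    "print(np.allclose(dL_dW1, np.outer(dL_dh1, x_n)))  # True\n",
    "# The @ contraction used in the loop class agrees for this diagonal tensor\n",
    "print(np.allclose(dL_dh1 @ dh1_dW1, dL_dW1))       # True\n",
    "```\n",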
    "\n",
    "Once the network is fit, we can form predictions with the `predict()` method. This simply consists of running test observations through the network and returning their outputs.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 607,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FeedForwardNeuralNetwork:\n",
    "    \n",
    "    def fit(self, X, y, n_hidden, f1 = 'ReLU', f2 = 'linear', loss = 'RSS', lr = 1e-5, n_iter = 1e3, seed = None):\n",
    "        \n",
    "        ## Store Information\n",
    "        self.X = X\n",
    "        self.y = y.reshape(len(y), -1)\n",
    "        self.N = len(X)\n",
    "        self.D_X = self.X.shape[1]\n",
    "        self.D_y = self.y.shape[1]\n",
    "        self.D_h = n_hidden\n",
    "        self.f1, self.f2 = f1, f2\n",
    "        self.loss = loss\n",
    "        self.lr = lr\n",
    "        self.n_iter = int(n_iter)\n",
    "        self.seed = seed\n",
    "        \n",
    "        ## Instantiate Weights\n",
    "        np.random.seed(self.seed)\n",
    "        self.W1 = np.random.randn(self.D_h, self.D_X)/5\n",
    "        self.c1 = np.random.randn(self.D_h, 1)/5\n",
    "        self.W2 = np.random.randn(self.D_y, self.D_h)/5\n",
    "        self.c2 = np.random.randn(self.D_y, 1)/5\n",
    "        \n",
    "        ## Instantiate Outputs\n",
    "        self.h1 = np.dot(self.W1, self.X.T) + self.c1\n",
    "        self.z1 = activation_function_dict[f1](self.h1)\n",
    "        self.h2 = np.dot(self.W2, self.z1) + self.c2\n",
    "        self.yhat = activation_function_dict[f2](self.h2)\n",
    "        \n",
    "        ## Fit Weights\n",
    "        for iteration in range(self.n_iter):\n",
    "            \n",
    "            dL_dW2 = 0\n",
    "            dL_dc2 = 0\n",
    "            dL_dW1 = 0\n",
    "            dL_dc1 = 0\n",
    "            \n",
    "            for n in range(self.N):\n",
    "                \n",
    "                # dL_dyhat\n",
    "                if loss == 'RSS':\n",
    "                    dL_dyhat = -2*(self.y[n] - self.yhat[:,n]).T # (1, D_y)\n",
    "                elif loss == 'log':\n",
    "                    dL_dyhat = (-(self.y[n]/self.yhat[:,n]) + (1-self.y[n])/(1-self.yhat[:,n])).T # (1, D_y)\n",
    "                \n",
    "        \n",
    "                ## LAYER 2 ## \n",
    "                # dyhat_dh2 \n",
    "                if f2 == 'linear':\n",
    "                    dyhat_dh2 = np.eye(self.D_y) # (D_y, D_y)\n",
    "                elif f2 == 'sigmoid':\n",
    "                    dyhat_dh2 = np.diag(sigmoid(self.h2[:,n])*(1-sigmoid(self.h2[:,n]))) # (D_y, D_y)\n",
    "                    \n",
    "                # dh2_dc2\n",
    "                dh2_dc2 = np.eye(self.D_y) # (D_y, D_y)\n",
    "                \n",
    "                # dh2_dW2: slice i is d h2_i / d W2, whose only nonzero row is row i\n",
    "                dh2_dW2 = np.zeros((self.D_y, self.D_y, self.D_h)) # (D_y, (D_y, D_h))\n",
    "                for i in range(self.D_y):\n",
    "                    dh2_dW2[i, i] = self.z1[:,n]\n",
    "                \n",
    "                # dh2_dz1\n",
    "                dh2_dz1 = self.W2 # (D_y, D_h)\n",
    "                \n",
    "                \n",
    "                ## LAYER 1 ##\n",
    "                # dz1_dh1\n",
    "                if f1 == 'ReLU':\n",
    "                    dz1_dh1 = 1*np.diag(self.h1[:,n] > 0) # (D_h, D_h)                \n",
    "                elif f1 == 'linear':\n",
    "                    dz1_dh1 = np.eye(self.D_h) # (D_h, D_h)\n",
    "\n",
    "                \n",
    "                # dh1_dc1 \n",
    "                dh1_dc1 = np.eye(self.D_h) # (D_h, D_h)\n",
    "                \n",
    "                # dh1_dW1: slice i is d h1_i / d W1, whose only nonzero row is row i\n",
    "                dh1_dW1 = np.zeros((self.D_h, self.D_h, self.D_X)) # (D_h, (D_h, D_X))\n",
    "                for i in range(self.D_h):\n",
    "                    dh1_dW1[i, i] = self.X[n]\n",
    "                \n",
    "                \n",
    "                ## DERIVATIVES W.R.T. LOSS ## \n",
    "                dL_dh2 = dL_dyhat @ dyhat_dh2\n",
    "                dL_dW2 += dL_dh2 @ dh2_dW2\n",
    "                dL_dc2 += dL_dh2 @ dh2_dc2\n",
    "                dL_dh1 = dL_dh2 @ dh2_dz1 @ dz1_dh1\n",
    "                dL_dW1 += dL_dh1 @ dh1_dW1\n",
    "                dL_dc1 += dL_dh1 @ dh1_dc1\n",
    "            \n",
    "            ## Update Weights\n",
    "            self.W1 -= self.lr * dL_dW1\n",
    "            self.c1 -= self.lr * dL_dc1.reshape(-1, 1)           \n",
    "            self.W2 -= self.lr * dL_dW2            \n",
    "            self.c2 -= self.lr * dL_dc2.reshape(-1, 1)                    \n",
    "            \n",
    "            ## Update Outputs\n",
    "            self.h1 = np.dot(self.W1, self.X.T) + self.c1\n",
    "            self.z1 = activation_function_dict[f1](self.h1)\n",
    "            self.h2 = np.dot(self.W2, self.z1) + self.c2\n",
    "            self.yhat = activation_function_dict[f2](self.h2)\n",
    "            \n",
    "    def predict(self, X_test):\n",
    "        self.h1 = np.dot(self.W1, X_test.T) + self.c1\n",
    "        self.z1 = activation_function_dict[self.f1](self.h1)\n",
    "        self.h2 = np.dot(self.W2, self.z1) + self.c2\n",
    "        self.yhat = activation_function_dict[self.f2](self.h2)        \n",
    "        return self.yhat\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try building a network with this class using the `boston` housing data. This network contains 8 neurons in its hidden layer and uses the ReLU and linear activation functions after the first and second layers, respectively."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 625,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEcCAYAAAAoSqjDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3df3RU5bkv8O8zk0wyhGAwJoiAxSplmUtRTPzJWV2K1eOpLJELes+xIJ5aweuxtb09Fe0tq/bS3iXVc+hxtVzgaC1qqXpUDi57j8KhcHsP/gyiiNBcSgX5JQmRSAhhJsk894/Zs52Z7D2ZmezZe8/M97OWK5mdSfabLdnPft/nfZ9XVBVEREQAEPC6AURE5B8MCkREZGJQICIiE4MCERGZGBSIiMjEoEBERCYGBSIiMjEoEPmQiHxNRL7mdTuo/AgXrxH5i4icBWCD8fI6Ve30sj1UXhgUiHxGRH4JYB2AIICbVPXvPG4SlREGBSIiMjGnQEREJgYFIiIyMSgQGURkpIgMiMjYpGNTROSIiNSmvfcBEXkh7dg/ichjxueLReSQiHSLSJuIXOvk+YkKhUGByKCqJwH8EcAlSYcfBvA/VbU77e2/BfA1ERkFACISBHArgLUiMhnAvQAuVdVaAH8JYJ/D5ycqCAYFolTvwLgpi8hXADQBWJX+JlXdD+BdADcbh2YAOKWqbwIYAFAFoElEKlV1n6rudfL8RIXCoECUyrwpA/gZgCWqGrV571oAf2N8fpvxGqr6JwDfAfAQgHYReVZEzinA+Ykcx6BAlOodAJeIyBwAYcSHiez8C4CrRWQ8gNkwggIAqOpaVf0LAF8AoACWFeD8RI7jOgWiJCJSBeAEgCMA/quq/tsQ7/83ABUAzlLVacaxyQDGAdiKeEBYCSCgqnc4fX4ip7GnQJREVSMAPgCwL8sb8loAX0VSLwHxfMLDAI4B+ARAI4AfAPEgIiI/cPD8RI5iT4EoiYiEAPwJwK1G0riszk/EngJRqh8B2OrhDdnr81OZY1AgAiAil4jIZwC+AuBb5XZ+ogQOHxERkYk9BSIiMlV4cVIR2QegG/GVn/2q2iIiZwJ4DsBExEsC3KqqxzP9nBtuuEFfffXVwjaWiKj0iN0XvOwpXKOqF6tqi/H6AQCbVHUSgE3G64yOHTtWyPYREZUdPw0fzQKwxvh8DT6vKUNERC7xKigogA0isk1EFhrHxqjqEQAwPjZafaOILBSRVhFp7ejocKm5RETlwZOcAoDpqnpYRBoBbBSRP2b7jaq6GsBqAGhpaeHUKSIiB3nSU1DVw8bHdsQ3KL8MwNHE5iLGx3Yv2kZEVM5cDwoiUpPYRUpEagBcD2AngJcBLDDetgDAerfbRkRU7rwYPhoDYJ2IJM6/VlVfFZF3ADwvIncC+BjALR60jYjIdbGYorMnimj/AEIVQdTXhBAI2M4aLSjXg4Kq/hnARRbHOwEMuY8tEVEpicUUbUe7cddTrTh4vBfjR4fxz7e3YPKYWk8Cg5+mpBIRlZ3OnqgZEADg4PFe3PVUKzp7vNlwj0GBiMhD0f4BMyAkHDzei2j/gCftYVAgIvJQqCKI8aPDKcfGjw4jVBH0pD0MCkREHqqvCeGfb28xA0Mip1BfE/KkPV4tXiMiIgCBgGDymFqsu2d6ec4+IiKiVIGAoKG2yutmAODwERERJWFQICIiE4MCERGZGBSIiMjEoEBERCYGBSIiMjEoEBGRiUGBiIhMXLxGRFRECr33AoMCEVGRcGPvBQ4fEREVCTf2XmBPgYgc46dtJUuRG3svMCgQkSP8tq1kKUrsvZAcGJzee0FU1bEf5raWlhZtbW31uhlEBKCjO4LZK7YOumGtu2e6byqAFrtYTLGvswf7O09hRCiIU9EBfKF+BCbW1+QaeG3fzJ4CETnCb9tKlqpIfwxL1u9M6Y05iYlmInKE37aVLEVuJJoZFIjI
EX7bVrIUMdFMREXDb9tKusHt2VZuJJrZUyAixyS2lRw3egQaaqtKPiC0He3G7BVbMX3ZZsxesRVtR7sRixVu8o4bvTHOPiIiyoNXs60c6p1w9hERkZO8mm2V6I0V7OcX7CcTEZWwUp1txaBARJSHUp1txeEjIqI8lOpsKwYFIqI8FXp83wsMCkREWSqHKrAMCkREWSiXKrBMNBMRZcGNukN+wKBARJSFcqkCy6BARJSFUl2XkI5BgYgoC6W6LiEdE81ERFko1XUJ6TwLCiISBNAK4JCqzhSR8wA8C+BMAO8CmK+qpZXBIaKiVorrEtJ5OXx0H4DdSa+XAViuqpMAHAdwpyetIs/FYoqO7ggOHT+Fju5IQUsRE1EqT4KCiIwHcCOAx43XAmAGgBeMt6wBcLMXbSNveVGjnog+51VP4ecA7gcQM17XA+hS1X7j9UEA46y+UUQWikiriLR2dHQUvqXkKj/OBWfPhcqJ6zkFEZkJoF1Vt4nI1YnDFm+1/MtT1dUAVgPxTXYK0kjyjN/mgpfLKtZyUQ5lKobLi57CdAA3icg+xBPLMxDvOdSJSCJIjQdw2IO2kcf8Nhfcjz0Xyg+HJrPjelBQ1QdVdbyqTgTw1wB+r6pfB7AZwFzjbQsArHe7beQ9v80F91vPhfLHAJ8dP61TWAzgWRH5CYDtAJ7wuD3kAb/NBU/0XNL34S21VazlgAE+O56uaFbVLao60/j8z6p6mapeoKq3qGrEy7aRdxJzwceNHoGG2ipPx3z91nOh/PltaNKvRLV4x9NaWlq0tbXV62ZQiWNysjRw0kAK21+YQYGIygYDvMn2l/ZTToGIqKDKoUzFcLFKKhERmRgUiIjIxKBAREQm5hSIShATqpQvBgWiEuP01EsGmPLC4SOiEuNkOQfWCyo/DApEJcbJcg7ZBhiWFy8dHD4iKjFO1mvKJsBwpXBpYU+BqMQ4Wa8pm3pBrD7qrkL3ythTIPK5XBO9TlaaTQSY9F5AcoBh9VH3uNErY1Ag8rF8bwJOlXPIJsCwvLh77Hpl6+6Z7lj5Dg4fEfmYH4ZmhiplzvLi7nGjV8aeApGPFcPQjN82RiplbvTK2FMg8jEvNobJJ5Hpp42RSpkbvTLup0DkY25P9+T0Uv9zaIU5N9khKlZulpno6I5g9oqtg4YnnExkki/Y/gPi8BGRzwUCgvqaEEIVQUT7B9DZEy3YiuFiyGFQYTHRTORzbg7pcHopsadA5HNuTkvl9FJiT4HI5+yGdHr7BnDo+ClH8wycXkoMCkQ+lJxcFhHLIZ297Sfxt79+x/HhJG5uX94YFIh8Jj2HcH1TI1bOa8bdz2wzcwqPzJ2Kde8ewqr5zagLV+KTz05jzKgqnFnDmzkND4MCkc+k5xA27GoHADy/6EokppD/07/vwaxp47D4xR1moFg1rxl14dyHerizGiVjopnIZ6xyCBt2tUNVMW70CIQqgvirL481AwIQzzEsemZbzsln7qxG6RgUiHxmqNIWo8OV+GJDjSPrCfxQcI/8hUGByGcyTQuNxRR7Ok7izx09jtRE4mI1SsecApHPZJoW2tEdwV1PtaJhZBWWzZmaklN46huXQaE5TVPlYjVKx6BA5EN200ITT/YHj/fi0dfasGRmE+rClZhYPwKfnurD7Stez2nVczY7q1F5YUE8oiJiV7Du+UVX4tZVb+RVyI6zj8oSC+IRlQK7fIOq5p0b4F4IlIzDR0RFxC7f0NkTZW6AHMGeAlGRsXqyZyE7cgp7CkQlgIXsyCkMCkQlgoXsyAkcPiIiIpPrQUFEqkXkbRF5X0Q+FJEfG8fPE5G3RGSPiDwnIhwMJSJymRc9hQiAGap6EYCLAdwgIlcAWAZguapOAnAcwJ0etI2IqKy5HhQ07qTxstL4TwHMAPCCcXwNgJvdbhuVtlhM0dEdwaHjp9DRHWElUAOvCyXzJNEsIkEA2wBcAOCXAPYC6FLVfuMtBwGMs/nehQAWAsC5555b+MZSSUjf
uMbp3cqKFa8LpfMk0ayqA6p6MYDxAC4DcKHV22y+d7WqtqhqS0NDQyGbSSWEJaKt8bpQOk+npKpql4hsAXAFgDoRqTB6C+MBHPaybVRaci0RXS71gIZbOrtcrlM58WL2UYOI1BmfhwF8FcBuAJsBzDXetgDAerfbRqVrqI1rkpXTbmS5XJd05XSdyokXw0djAWwWkR0A3gGwUVVfAbAYwH8TkT8BqAfwhAdtoxKVSxmIchpSGU55jHK6TuXE9eEjVd0BYJrF8T8jnl8gclwuZSDyHVIpxqGU4ZTH4K5tpYllLqhsZFsGIp/dyIp5Fk++5TG4a1tpYpkLKhpuzafPZ0ilHIdSWJm1NLGnQEXB7SfxqooAls6aghGhIE5FB1BVkfn5qRyHUliZtTQxKFBRsHsSz2a7yXzOdfuv3s5pa8tsh1KKMe+QCSuzlh4OH1FRcPNJPJ9zZTOUkujt/Pd1O7Dz8Ans7+zBweOn0N8fc/x3IMpXTj0FEfl3AN9T1fcL1B4iS24kNRNP8QOqePKOS/HYpj3YfqArq3NlM5TS2RPF8o1tWHDVeVj84g5zGGzV/GZcePaoou4xUOnI2FMQkSYReSbp0P0AlovIkyIytrBNI/pcoZOayQuxvvKzLViyfifuv2Eypk2oy/pcVttkJov2D2BO8wQzIADxHsiip7fhWE+ERenIF4bqKWwCcGXihaq+C2CGiMwB8KqIvATgZ6raa/cDiJxQ6KSmVc7i+y/swHMLr3DsXImfYzU0dSoygHlPvFV001mp9AyVU7gewE+TD4iIAGgD8L8AfAvAHhGZX5jmEX1uqCfx4bDLIwCwPVeuU2Tra0JorK2yLCvx0bGesprOSv6VMSio6geq+vXEaxH5DwCHACxHvLT1HQCuBnCZiKwuXDOJCivXGkD51P0JBATnnBHPISQPg62a14zHNu1JeW+pT2cl/8p1SurdAD5U1fR/+d8Skd0OtYkoK05O70zkLNLXQdjlEfKdIltREcCFZ49KGQYLBoCOk5GU93FlMHklp6CgqjszfPnGYbaFKKPkIBAOBXH0RMSxxWy55iyGM0U2fW5/LKY5BSSiQnJs8ZpR0I6oINJXND95x6VYsn6no4vZclmI5eQUWa4MJj/h4jUqCunDNSNCQU/LSjg9RbaQSXSiXLDMBbkun1xA+nBNV29f3k/q2Zx/qPfw6Z5KFYNCmfG69k6+he3Sh2tWbtmLR+ZOxfdf2GE7Dm/1uwIY8vzZtpF1f6gUyeCJRMWjpaVFW1tbvW5G0XC70qjVTbmzJ4rZK7bmVGzOru1PfeMyjKyuQF9/bFCAs/td60eG8J9XvJ7x/B3dkbzaSFREbP/g2VMoI25WGk3clJdvbMOc5gmorwmhN9qPqopAXrmAXIdr7H7XtXddPuT5y7EMNlECg0IZcfNmZ1v8bV4zrm9qxIZd7eZ7s80FZBquSe+VxGIxy981KDJkLoI7ilE54+yjMpLrqt3hsC3+9sw2/PDGpiFn7eRSQsJqdfGxniiub2pMed/40WGEQ8EhZw0Nd2aRWzvEERUCcwolJlMi2c2cQkd3BPs7ezB35RuDvrZ18TUIVQRth4FybaddDmDtNy/HbY8PLjIHYNizj+wU817NVFZs/zEyKJSQbG5IhZp9lP5zR4crcfizXvOmnJBNwjbXRO+h46cwfdlmAMC0CXW4++rzUReuNHsGvVH3ZloxSU1FwvYPgcNHJSSbzeMLsUjKavhmT8dJjB1VPaj4WzbDMLnmPhLDYtMm1OFHNzUhFIz/s/5/R0+i61Qfxp4Rdm1BGJPUVOyYaC4hXt2QMs1qSi/+ls3Teq6J3kQO4OTpfvRGB8zyF+NHh/HI3KmoG1GJM2vceUpnkpqKHXsKJcTNRHKyTMEon55JronexHTVsXXV5mK2RBu+/8IO9Ebde0ov9A5xRIXGnkIJybX8s1OcfjpO3ORfvnc6eqMDGFBFdWXmnxUICGKqlsFpwMW0GctfULFjUCgh
Xt2Q7ILR6HAlOrojebcl19LY1ZXWwam60t0OMctfUDHj7CNyhNXsoz0dJ/OemplpFk+iXEZ6sOF0UKKsscwFFVb603FHd2RYJTXs8hSxWCzjjZ9DN0TDw6BABZHLTCirtRPhUBBP3nEpRoSC6Ortw8ote9FxMoIBRcZgw6EbouFhUKCCyDb5HIsp9nX2YH/nKdSNqERtdSVisRiO9UQHTS0dM6oaapNM5joAImdwSipZGm79nmynZnb1RnH0xGksWb8Ts1e8jjuefBs90QEsenrboKml4cogKisCnky7JSoX7CnQIEMlbLMplZHt+H5vdGDQ2oJPe6KWvYFDXb0406Npt0TlgkGBBsm0Qrm+JpT1DJ9sxvcHLIaDOnuilkNPnT1RfOe59/DyvdOZTCYqEA4f0SCZksR2AeOTE6ezHmpKHpqqDAweDnpx2wH88rZLUoaels2ZipVb9uLg8V70RgdyWiXNUtZE2WNPgQbJlCS2CxiHu3oxd+UbQ64NSB+aur6pESvnNePuZ7aZPY8FV52H/73jEJ5deAU++ew0OnuiePS1Nmw/0JVz/oBrF4hyw8VrZErOFQzEFD/53S5s2NWeciO122N5ycwmLHp6m/nabj3Cpz0RvH/gs5Sppg21ITx00xSoKkQEQQECgcCQC+CschtA6l4JCh1yT2aiMsTFa5SZ1RP1qvnNWDprCgKBgDlub1XSYtmcqXj0tTbzZ2Vaj3Ck63TKVNPE96oqxo0eMeh77JLVdj2AqooAbv/V2+axZ+4cek9mIvqc6zkFEZkgIptFZLeIfCgi9xnHzxSRjSKyx/g42u22lTOrXMGip7chEAikjNsnzyrauvgaPL/oSqx5/SNsP9Bl/iy7IZ7OnigWPZM61XTxizvw7Wsn5Tyl1C63sb/zlHmsYWQV+mPKKaxEOfCip9AP4Huq+q6I1ALYJiIbAdwBYJOqPiwiDwB4AMBiD9pXlnJZgZw8qygWU3z3usnYdaQbDSOr8O1rJ+G8s2qgUMRimjJub3eO886qsZxS2t8fw+HPetHeHUFnTxQvbjuA7143GZPH1Fr+rIaRVZhYPwLPLbwCfQMxjKyuwM9e3Y1lc6aae0VzCitRZq4HBVU9AuCI8Xm3iOwGMA7ALABXG29bA2ALGBRck0v56/Sx/EkNI/HyvdNxpOu02ROwSujanWNEVdByj+S29m5zEVtiqGn5xjb8dPbUQT9r2oQ63H/DZMxPGjp6ZO5UdHTHk9RLZjahviaEc+rCOHtUNZPMRDY8TTSLyEQAfwAwBcDHqlqX9LXjqjpoCElEFgJYCADnnntu8/79+91pbInLdpaO3fvqR4YyJnRjMcWxnghORQbw0bEePLZpDzpORmxnAtlVSV0yswlTzhmFsWeEU9rx5B2XmrmK9PcnEuAAsHXxNZa5C6Iy479Es4iMBPAigO+o6gmR7J7cVHU1gNVAfPZR4VpYHLJZXZyNbFcg243l//Yu64Tu6b4B9PfHBs0iWjWvGWPrqlEXtm6v3VBTfU0IoYrgoPZaLYI7eLwXdeFK8zVzCURD82TxmohUIh4QfqOqLxmHj4rIWOPrYwG0e9G2YpJ4ap+9YiumL9uM2Su2ou1od96Ls7LZOtPuZh0QsUzoRoy8wKAk9jPbMBCDeY70BWbhkPXWoo21VWY+ILm94coKy/efMrbiZC6BKDtezD4SAE8A2K2q/5j0pZcBLDA+XwBgvdttKzZ2T+2dPdGCndNuH+hAAFg2Z+qgVci90X60d0cyJrGtgtvRExE89Y3LUn7eqvnNOOeMsGWwsivAd9GEM7B18TVYd890LlgjyoIXw0fTAcwH8IGIvGcc+wGAhwE8LyJ3AvgYwC0etK2o5DJjyCl2W28GAwGsef0jLJnZhLpwJbp6+7Dm9Y8wp3kCANgmsWMxxScnTlsGt+cXXWnu05y+RsFqyMx2+KumYJeDqOR4MfvoP2Cf5LjWzbYUu1xmDDnF7uYLAN+9brLloraG2hBW
zW9OmUmU2MO57Wg3eiL9tqUzaqoqUp7wh0qIc5Uy0fCwzEUR81tdn0wzjCY1jMTx3r6UQJIombFkZhOWvrLLcubQ0ld2pZSkaO8+bTnL6aV7rkJjbbXrvzNRkfLf7CMaPr/tSRwICBprqxGrUdRUVeAXt01LaVP6U3xi+Gvllr2DFpglehnpw2Gn+6yHzE73xVz5HYlKHYNCkfPjkEm2bUoMf20/0IVHX2vDk3dcis96+zJWRQ0as5zSewpB5o+JHMH9FMgzyTOGth/owiOv/RHVlUEsfWWXGRDSp5GGQ0E8Mjd1ltMjc6ciHOL6AyInMKdAw2I3EygWU3T1RtEbjS8sq64M4qyawWsf0r9/dLhyUO4hfUX1vs4e7O88hRGhIE5FB/CF+hGYWF/D6aZE2WNOgZxnl+i+4KwafNIdQXQgho87T+GxTXvQUBvCD29sQjAgKTd7q6GmTENPgYBgYn0NaqsrfZFHISo1DAqUN6vFc8s3tuG+a7+UUhjvH265CFWVAdz2+FuOzJLyYx6FqFQwp1BE/LbXsNXiuTnNE8yAMG1CHZbMjPcORlZVomFk/EbuxsprIsoPewpFItPeAgByLornRCE9q8Vz9TUhMyD8/V9Otpxmuv1AF3c/I/Ip9hR8KrlX8GlPBG3t3bjt8bcwd+UbWPrKLiy46jws39iGrt7ooLpBu4+cwKc99j0JpwrpWdUbaqytwvjRYdx99flmQAA+32Xt7qvPN9/LiqVE/sPZRz6S/PQ+EFP85He7sGFXe8a9AqacMwr/ZfWbg762dNYUnH1GdU57FeS6mb3VTKALGmtwMjKAnkg/5q58Y9D3PLfwCnzvX973dOU1EXH2ke9ZzeRZNie+c9iIUNB2bwG7fQRGhIK466lWyxu9U4X0OnuiuN3Y6QyI73727Wsn4YLGGoyqrsD1TY3YsOvzCujjR4fN4MMZQ0T+xKDgE1YzeRa/uANLZjahq7fPchVvY20Vqiuti+J19fbZ3uidKqSXHFyscggr5zUDADbsajdnHI21KX1NRP7AnIJP2D2914UrsXLL3kGreFfNa0ZVRfx/X/q+A8vmTMXKLXsxfnQYAzEdlCuw23sg1w1okvdWsMoh3P3MNvz4pinYuvgarL3rctSP5AY3RH7HnoKLMs34sXt67+rtQ8fJCMaMqsZL91yFvv5YSr4hcUNff+9VONE7gH3HevDoa23oOBnBsjlT8Zs392HB9C9CVbPbeyAHieCyfGMbvtQ4Ev9wy0Xo6u3Dyi17zRlGkf4Y5j3hzPoEIio8JppdMlSZa6uvr5rfjLNqQggEAuZNO1OSOBaL4b2Dn5mb3GzadRSzpo1LGdJx+qbc3x9DW3t3yl4JiamnHScjWDprCv721+8MaisXnxF5yvYGwKDgkmxm/Fj1JIDUNQjR/gFMX7Z50M/fuvgahCqCKedYNb/Zcp+C4d6Uk9spIrh11RuWs58aa6vww3/die0Huga1ddzoEXmfn4iGzTYoMKfgkmxm/CRvRJ+4aaevJxiIqeUeyYkgkpwrOHtU9ZB7I+e6Qjp9jcPhrl7Lc5zfOBJj66rRcTJi2VYi8icGBZfYbXif6QZpNSPpJ7/bhVXzmy2TxMm5grcenIG6EZW258x3AVt6mzp7opbnCFcGURd2JqFNRO7h8JFL8tk689DxU5ZDRW89OAOBQMCyXHXysM5DL+/EgqvOS8kprJrfjAvPHmVuhZn8lH99UyMeumnKoKR0pjZZTUVNz5UMt5wGETmOi9e8luuMn1hMITa7jAUCgUE5gfSg88LdV2LDrnZ0dEexZGaTmXw+yzhn+nDWtAl1WHDVeWZ+wC5opc+S2n6gC2te/wjPL7rSMpiwoilRceHwkYvScwaZAkLb0W489PJOLJszNavhl2yGdULBACqNtQ3pw1lW6wysKplarXH47nWTcfao6iF/LyLyP/YUfCj5Bp940q+vCWHsGdW2206mP/mv3LIXv7htGnqjA/j+C0lDO/Nb
UBcOmTf3xHkS1U2TWa2IdmqNAxH5E3sKPpA+Cyj5Br/9QBcWPb0Nc1fGh3Vu+oV1Qjj9yX/7gS6c7ouZAQEwnv6fbsWxnkjKzX3r4mtwTl0460S41SwpP+3zQET5Y1DwmNUsILtpp4l6RtkO64wZVWX59H+6LwYg9eZ+9qjqvGYKOVWGm4j8gcNHeXByRk2maafpq4TXbz+EVfObUReO708ci2lKQjd9WKevf8AyUR20aWpVRQBLZ00xy2Anaivl2n676qxE5H8MCjnKZ2ppJlaL2jbsasfSWVOw7p7p6O0bwN72k1i//dCQJSvSZ/p82hPBI3OnpuQUHpk71TIvkV4GG8hu9bNTZbiJyB84fJQjuyfjfPcbtlvUlph2Or4ujLPPqMZffXlsVrODktWFQxgzqhpLZ03BcwuvwNJZUzBmVDXqwoOHhPK9ueezKI+I/Is9hSGkDxU5/WScPgsofSw/MSxUU2W90U6m8wYCgon1NaitrhxyqCvfPRZGhyux9puXD9o7mquWiYoTg0IGVkNFa795uSMb1CRkM8UzEBCEKyvyOm+2i8eGCk5WYjHFno6Tgyq7TmoYySmqREWKZS4ysKpsen1TI+776pdSksBu7BHgdC7D7hy5JNCd2uuZiFzHMhf5GCoJ7ObirUSP4qV7rsLpvhiCAtuFbMM5Ry43cyaZiUoPg0IGduPsVrWH3NJ5MlrQ3kIunNrrmYj8g7OPMnBqL2OnOD3zabj8dn2IaPjYU8jAb3V+vBiuyZRn8Nv1IaLhY1AYwnBKPzu9l4DbwzXZJLdZGpuotHD4qEAKURPI7eEavw1XEVHhedJTEJFfAZgJoF1VpxjHzgTwHICJAPYBuFVVj3vRPicUoiaQ28M1nF1EVH686in8GsANacceALBJVScB2GS8LlqFuqFmu1GPE1jCgqj8eBIUVPUPAD5NOzwLwBrj8zUAbna1UQ4rhRsqZxcRlR8/JZrHqOoRAFDVIyLS6HWDhiOfshF+w9lFROXHT0EhKyKyEMBCADj33HM9bo29UrmhcnYRUXnx0+yjoyIyFgCMj+1Wb1LV1araoqotDQ0NrjYwV26O/xMROcFPQeFlAAuMzxcAWO9hW4iIypInQUFEfgvgDQCTReSgiIkIJo0AAAUvSURBVNwJ4GEA14nIHgDXGa+JiMhFnuQUVPVvbL50rasNISKiFH4aPiIiIo8xKBARkanopqS6yemCdkREfsegYMON7S+JiPyGw0c2WCGUiMoRewo28i1oxyEnIipmDAo28tnQhkNORFTsOHxkI58KoRxyIqJix56CjXwK2nFTGiIqdmUXFHIZ88+1QqjbeygTETmtrIaPCrFvcjJuSkNExU5UnbkheqGlpUVbW1uzfn9HdwSzV2wd9CQ/nH2T03H2EREVAdubUlkNH7kx5s9NaYiomJXV8FEp7JtMRFRIZRUUOOZPRJRZWQ0flcq+yUREhVJWQQHgmD8RUSZlNXxERESZMSgQEZGJQYGIiEwMCkREZGJQICIiU1GXuRCRDgD7vW6H4SwAx7xuhE/x2tjjtcmM18fecK7NMVW9weoLRR0U/EREWlW1xet2+BGvjT1em8x4fewV6tpw+IiIiEwMCkREZGJQcM5qrxvgY7w29nhtMuP1sVeQa8OcAhERmdhTICIiE4MCERGZGBTyICK/EpF2EdmZdOxMEdkoInuMj6O9bKNXRGSCiGwWkd0i8qGI3GccL/vrIyLVIvK2iLxvXJsfG8fPE5G3jGvznIiU7QYfIhIUke0i8orxmtcGgIjsE5EPROQ9EWk1jhXkb4pBIT+/BpC+8OMBAJtUdRKATcbrctQP4HuqeiGAKwD8nYg0gdcHACIAZqjqRQAuBnCDiFwBYBmA5ca1OQ7gTg/b6LX7AOxOes1r87lrVPXipLUJBfmbYlDIg6r+AcCnaYdnAVhjfL4GwM2uNsonVPWIqr5rfN6N+B/4OPD6QONOGi8rjf8UwAwALxjHy/LaAICI
jAdwI4DHjdcCXptMCvI3xaDgnDGqegSI3xgBNHrcHs+JyEQA0wC8BV4fAObwyHsA2gFsBLAXQJeq9htvOYh4EC1HPwdwP4CY8boevDYJCmCDiGwTkYXGsYL8TZXdzmvkDhEZCeBFAN9R1RPxhz5S1QEAF4tIHYB1AC60epu7rfKeiMwE0K6q20Tk6sRhi7eW3bUxTFfVwyLSCGCjiPyxUCdiT8E5R0VkLAAYH9s9bo9nRKQS8YDwG1V9yTjM65NEVbsAbEE871InIokHtPEADnvVLg9NB3CTiOwD8Cziw0Y/B68NAEBVDxsf2xF/mLgMBfqbYlBwzssAFhifLwCw3sO2eMYYB34CwG5V/cekL5X99RGRBqOHABEJA/gq4jmXzQDmGm8ry2ujqg+q6nhVnQjgrwH8XlW/Dl4biEiNiNQmPgdwPYCdKNDfFFc050FEfgvgasRL1x4F8CMA/wrgeQDnAvgYwC2qmp6MLnki8hcA/i+AD/D52PAPEM8rlPX1EZGpiCcEg4g/kD2vqv9DRL6I+NPxmQC2A5inqhHvWuotY/jo71V1Jq8NYFyDdcbLCgBrVfWnIlKPAvxNMSgQEZGJw0dERGRiUCAiIhODAhERmRgUiIjIxKBAREQmBgUiIjIxKBARkYlBgcghIvJlEdma9PoSEfm9l20iyhUXrxE5REQCiNfmGaeqAyKyGfG9Jd71uGlEWWOVVCKHqGpMRD4E8J9EZBKAjxkQqNgwKBA5603EK37eg8G78xH5HoMCkbPeRHy71l+q6iGP20KUM+YUiBxkDBv9HwCTVLXH6/YQ5Yqzj4icdR+ABxkQqFgxKBA5QETON7ZIDKvqmiG/gcinOHxEREQm9hSIiMjEoEBERCYGBSIiMjEoEBGRiUGBiIhMDApERGRiUCAiItP/B1CdeIJEV/hlAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "ffnn = FeedForwardNeuralNetwork()\n",
    "ffnn.fit(X_boston_train, y_boston_train, n_hidden = 8)\n",
    "y_boston_test_hat = ffnn.predict(X_boston_test)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "sns.scatterplot(x = y_boston_test, y = y_boston_test_hat[0], ax = ax)\n",
    "ax.set(xlabel = r'$y$', ylabel = r'$\\hat{y}$', title = r'$y$ vs. $\\hat{y}$')\n",
    "sns.despine()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also build a network for binary classification. The model below attempts to predict whether an individual's cancer is malignant or benign. We use the log loss, the sigmoid activation function after the second layer, and the ReLU function after the first."
   ]
  },
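  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the log loss and a sigmoid output, the two factors that the classes multiply to get the derivative of the loss with respect to `h2` simplify neatly: their product is `yhat - y`. A quick numerical check with made-up values (not the cancer data) is sketched below. Computing `yhat - y` directly is a common alternative, since it avoids dividing by `yhat` and `1 - yhat` when predictions are close to 0 or 1.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sigmoid(h):\n",
    "    return 1 / (1 + np.exp(-h))\n",
    "\n",
    "# Assumed toy values: pre-activations and binary labels for four observations\n",
    "h2 = np.array([-1.5, -0.2, 0.4, 2.0])\n",
    "y = np.array([0.0, 1.0, 0.0, 1.0])\n",
    "yhat = sigmoid(h2)\n",
    "\n",
    "# The two factors computed separately in the classes\n",
    "dL_dyhat = -(y / yhat) + (1 - y) / (1 - yhat)\n",
    "dyhat_dh2 = yhat * (1 - yhat)\n",
    "\n",
    "# Their product simplifies to yhat - y\n",
    "print(np.allclose(dL_dyhat * dyhat_dh2, yhat - y))  # True\n",
    "```"
   ]
  },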
  {
   "cell_type": "code",
   "execution_count": 617,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9929577464788732"
      ]
     },
     "execution_count": 617,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ffnn = FeedForwardNeuralNetwork()\n",
    "ffnn.fit(X_cancer_train, y_cancer_train, n_hidden = 8,\n",
    "         loss = 'log', f2 = 'sigmoid', seed = 123, lr = 1e-4)\n",
    "y_cancer_test_hat = ffnn.predict(X_cancer_test)\n",
    "np.mean(y_cancer_test_hat.round() == y_cancer_test)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. The Matrix Approach"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below is a second class for fitting neural networks that runs *much* faster by simultaneously calculating the gradients across observations. The math behind these calculations is outlined in the {doc}`concept section </content/c7/concept>`. This class's fitting algorithm is identical to that of the one above with one big exception: we don't have to iterate over observations.\n",
    "\n",
    "Most of the following gradient calculations are straightforward. A few require a tensor dot product, which is easily done using numpy. Consider the following gradient:\n",
    "\n",
    "$$\n",
    "\\frac{\\partial \\mathcal{L}}{\\partial \\mathbf{W}^{(L)}_{i, j}} = \\sum_{n = 1}^N (\\nabla \\mathbf{H}^{(L)})_{i, n}\\cdot \\mathbf{Z}^{(L-1)}_{j, n}.\n",
    "$$\n",
    "\n",
    "In words, $\\partial\\mathcal{L}/\\partial \\mathbf{W}^{(L)}$ is a matrix whose $(i, j)^\\text{th}$ entry equals the sum across the $i^\\text{th}$ row of $\\nabla \\mathbf{H}^{(L)}$ multiplied element-wise with the $j^\\text{th}$ row of $\\mathbf{Z}^{(L-1)}$. \n",
    "\n",
    "This calculation can be accomplished with `np.tensordot(A, B, (1,1))`, where `A` is $\\nabla \\mathbf{H}^{(L)}$ and `B` is $\\mathbf{Z}^{(L-1)}$. `np.tensordot()` multiplies entries of `A` and `B` element-wise and sums the products over a specified pair of axes. The argument `(1,1)` pairs axis 1 of `A` with axis 1 of `B`, so we sum across the columns (the $N$ observations) of each; this is equivalent to `A @ B.T`.\n",
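    "\n",
    "As a quick sanity check, the sketch below compares `np.tensordot(A, B, (1,1))` with the explicit double sum from the formula and with `A @ B.T`. The shapes and random values are made up for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Assumed toy shapes: A plays the role of grad H^(L) (D_y x N), B of Z^(L-1) (D_h x N)\n",
    "rng = np.random.default_rng(0)\n",
    "D_y, D_h, N = 2, 3, 5\n",
    "A = rng.normal(size=(D_y, N))\n",
    "B = rng.normal(size=(D_h, N))\n",
    "\n",
    "# Explicit sum from the formula: entry (i, j) sums A[i, n] * B[j, n] over n\n",
    "explicit = np.zeros((D_y, D_h))\n",
    "for i in range(D_y):\n",
    "    for j in range(D_h):\n",
    "        explicit[i, j] = np.sum(A[i] * B[j])\n",
    "\n",
    "contracted = np.tensordot(A, B, (1, 1))  # contract axis 1 of A with axis 1 of B\n",
    "print(contracted.shape)                   # (2, 3)\n",
    "print(np.allclose(contracted, explicit))  # True\n",
    "print(np.allclose(contracted, A @ B.T))   # True\n",
    "```\n",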
    "\n",
    "Similarly, we will use the following gradient: \n",
    "\n",
    "$$\n",
    "\\frac{\\partial \\mathcal{L}}{\\partial \\mathbf{Z}^{(L-1)}_{i, n}} = \\sum_{d = 1}^{D_y} (\\nabla \\mathbf{H}^{(L)})_{d, n}\\cdot \\mathbf{W}^{(L)}_{d, i}.\n",
    "$$\n",
    "\n",
    "Letting `C` represent $\\mathbf{W}^{(L)}$, we can calculate this gradient in numpy with `np.tensordot(C, A, (0,0))`."
   ]
  },
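  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A finite-difference check is a common way to confirm gradient formulas like these. The sketch below, on made-up data, compares the matrix-approach gradients (built with the `(0,0)` and `(1,1)` contractions above) with central differences for a ReLU hidden layer and a linear output. It assumes the loss is half the RSS, which is what the class's `dL_dYhatt = -(Yt - Yhatt)` corresponds to; `forward`, `loss`, `analytic_grads`, and `numeric_grads` are hypothetical helper names.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Toy sizes and data, assumed for illustration\n",
    "rng = np.random.default_rng(1)\n",
    "N, D_X, D_h, D_Y = 6, 3, 4, 2\n",
    "Xt = rng.normal(size=(D_X, N))\n",
    "Yt = rng.normal(size=(D_Y, N))\n",
    "params = {'W1': rng.normal(size=(D_h, D_X)), 'c1': rng.normal(size=(D_h, 1)),\n",
    "          'W2': rng.normal(size=(D_Y, D_h)), 'c2': rng.normal(size=(D_Y, 1))}\n",
    "\n",
    "def forward(p):\n",
    "    # ReLU hidden layer and linear output, observations stored as columns\n",
    "    H1 = p['W1'] @ Xt + p['c1']\n",
    "    Z1 = np.maximum(H1, 0)\n",
    "    Yhatt = p['W2'] @ Z1 + p['c2']\n",
    "    return H1, Z1, Yhatt\n",
    "\n",
    "def loss(p):\n",
    "    # Half the RSS, so that dL/dYhat = -(Y - Yhat)\n",
    "    return 0.5 * np.sum((Yt - forward(p)[2]) ** 2)\n",
    "\n",
    "def analytic_grads(p):\n",
    "    H1, Z1, Yhatt = forward(p)\n",
    "    dL_dH2 = -(Yt - Yhatt)                          # (D_Y x N), linear output\n",
    "    dL_dZ1 = np.tensordot(p['W2'], dL_dH2, (0, 0))  # (D_h x N)\n",
    "    dL_dH1 = dL_dZ1 * (H1 > 0)                      # ReLU derivative is an indicator\n",
    "    return {'W1': np.tensordot(dL_dH1, Xt, (1, 1)), 'c1': dL_dH1.sum(1, keepdims=True),\n",
    "            'W2': np.tensordot(dL_dH2, Z1, (1, 1)), 'c2': dL_dH2.sum(1, keepdims=True)}\n",
    "\n",
    "def numeric_grads(p, eps=1e-6):\n",
    "    # Central finite differences, one parameter entry at a time\n",
    "    grads = {}\n",
    "    for name, value in p.items():\n",
    "        g = np.zeros_like(value)\n",
    "        for idx in np.ndindex(value.shape):\n",
    "            plus = {k: v.copy() for k, v in p.items()}\n",
    "            minus = {k: v.copy() for k, v in p.items()}\n",
    "            plus[name][idx] += eps\n",
    "            minus[name][idx] -= eps\n",
    "            g[idx] = (loss(plus) - loss(minus)) / (2 * eps)\n",
    "        grads[name] = g\n",
    "    return grads\n",
    "\n",
    "analytic, numeric = analytic_grads(params), numeric_grads(params)\n",
    "for name in params:\n",
    "    print(name, np.allclose(analytic[name], numeric[name], atol=1e-5))\n",
    "```\n",
    "\n",
    "If the analytic formulas are correct, every line should print `True`. Replacing `(H1 > 0)` with `np.maximum(H1, 0)` in `analytic_grads` would, in general, make the `W1` and `c1` comparisons fail."
   ]
  },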
  {
   "cell_type": "code",
   "execution_count": 634,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FeedForwardNeuralNetwork:\n",
    "    \n",
    "    \n",
    "    def fit(self, X, Y, n_hidden, f1 = 'ReLU', f2 = 'linear', loss = 'RSS', lr = 1e-5, n_iter = 5e3, seed = None):\n",
    "        \n",
    "        ## Store Information\n",
    "        self.X = X\n",
    "        self.Y = Y.reshape(len(Y), -1)\n",
    "        self.N = len(X)\n",
    "        self.D_X = self.X.shape[1]\n",
    "        self.D_Y = self.Y.shape[1]\n",
    "        self.Xt = self.X.T\n",
    "        self.Yt = self.Y.T\n",
    "        self.D_h = n_hidden\n",
    "        self.f1, self.f2 = f1, f2\n",
    "        self.loss = loss\n",
    "        self.lr = lr\n",
    "        self.n_iter = int(n_iter)\n",
    "        self.seed = seed\n",
    "        \n",
    "        ## Instantiate Weights\n",
    "        np.random.seed(self.seed)\n",
    "        self.W1 = np.random.randn(self.D_h, self.D_X)/5\n",
    "        self.c1 = np.random.randn(self.D_h, 1)/5\n",
    "        self.W2 = np.random.randn(self.D_Y, self.D_h)/5\n",
    "        self.c2 = np.random.randn(self.D_Y, 1)/5\n",
    "        \n",
    "        ## Instantiate Outputs\n",
    "        self.H1 = (self.W1 @ self.Xt) + self.c1\n",
    "        self.Z1 = activation_function_dict[self.f1](self.H1)\n",
    "        self.H2 = (self.W2 @ self.Z1) + self.c2\n",
    "        self.Yhatt = activation_function_dict[self.f2](self.H2)\n",
    "        \n",
    "        ## Fit Weights\n",
    "        for iteration in range(self.n_iter):\n",
    "            \n",
    "            # Yhat #\n",
    "            if self.loss == 'RSS':\n",
    "                self.dL_dYhatt = -(self.Yt - self.Yhatt) # (D_Y x N)\n",
    "            elif self.loss == 'log':\n",
    "                self.dL_dYhatt = (-(self.Yt/self.Yhatt) + (1-self.Yt)/(1-self.Yhatt)) # (D_y x N)\n",
    "            \n",
    "            # H2 #\n",
    "            if self.f2 == 'linear':\n",
    "                self.dYhatt_dH2 = np.ones((self.D_Y, self.N))\n",
    "            elif self.f2 == 'sigmoid':\n",
    "                self.dYhatt_dH2 = sigmoid(self.H2) * (1- sigmoid(self.H2))\n",
    "            self.dL_dH2 = self.dL_dYhatt * self.dYhatt_dH2 # (D_Y x N)\n",
    "\n",
    "            # c2 # \n",
    "            self.dL_dc2 = np.sum(self.dL_dH2, 1) # (D_y)\n",
    "            \n",
    "            # W2 # \n",
    "            self.dL_dW2 = np.tensordot(self.dL_dH2, self.Z1, (1,1)) # (D_Y x D_h)\n",
    "            \n",
    "            # Z1 #\n",
    "            self.dL_dZ1 = np.tensordot(self.W2, self.dL_dH2, (0, 0)) # (D_h x N)\n",
    "            \n",
    "            # H1 #\n",
    "            if self.f1 == 'ReLU':\n",
    "                self.dL_dH1 = self.dL_dZ1 * (self.H1 > 0) # (D_h x N); the ReLU derivative is an indicator\n",
    "            elif self.f1 == 'linear':\n",
    "                self.dL_dH1 = self.dL_dZ1 # (D_h x N)\n",
    "            \n",
    "            # c1 #\n",
    "            self.dL_dc1 = np.sum(self.dL_dH1, 1) # (D_h)\n",
    "            \n",
    "            # W1 # \n",
    "            self.dL_dW1 = np.tensordot(self.dL_dH1, self.Xt, (1,1)) # (D_h, D_X)\n",
    "            \n",
    "            ## Update Weights\n",
    "            self.W1 -= self.lr * self.dL_dW1\n",
    "            self.c1 -= self.lr * self.dL_dc1.reshape(-1, 1)           \n",
    "            self.W2 -= self.lr * self.dL_dW2            \n",
    "            self.c2 -= self.lr * self.dL_dc2.reshape(-1, 1)                    \n",
    "            \n",
    "            ## Update Outputs\n",
    "            self.H1 = (self.W1 @ self.Xt) + self.c1\n",
    "            self.Z1 = activation_function_dict[self.f1](self.H1)\n",
    "            self.H2 = (self.W2 @ self.Z1) + self.c2\n",
    "            self.Yhatt = activation_function_dict[self.f2](self.H2)  \n",
    "            \n",
    "    def predict(self, X_test):\n",
    "        X_testt = X_test.T\n",
    "        self.h1 = (self.W1 @ X_testt) + self.c1\n",
    "        self.z1 = activation_function_dict[self.f1](self.h1)\n",
    "        self.h2 = (self.W2 @ self.z1) + self.c2\n",
    "        self.Yhatt = activation_function_dict[self.f2](self.h2)        \n",
    "        return self.Yhatt\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We fit networks of this class in the same way as before. Examples of regression with the `boston` housing data and classification with the `breast_cancer` data are shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 637,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAEcCAYAAAAoSqjDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3de5RcZZnv8e9T1Zd00g0dYidcEoVhxYwRo5BMFDLLxWVAlkaBE2RmlJuoEDOMd42X4ehMxnOMcI6Oy8EkzqgBZZQhZGSxjghGGMcoSkIUMZpBJJKEmG7abtLpdPpWz/mjqzbV1VXVVZWq2ruqfp+1stJVXbX325X0++z9Ps/7vubuiIiIAMTCboCIiESHgoKIiAQUFEREJKCgICIiAQUFEREJKCiIiEhAQUFERAIKCiIRZGZvNLM3ht0OaTymyWsi0WJmLwEeTD682N17w2yPNBYFBZGIMbN/BrYCceAt7v43ITdJGoiCgoiIBJRTEBGRgIKCiIgEFBREksys3czGzeyUtOfOMrODZtaR8dqPmdk9Gc/9k5l9Mfn1WjM7YGYDZrbHzC4q5/lFKkVBQSTJ3Y8AvwHOSXv6s8D/cveBjJf/G/BGMzsBwMziwFXAXWa2CLgZ+DN37wDeAOwt8/lFKkJBQWSyx0h2ymb2emAxsDHzRe7+e+Bx4PLkUxcCR939UWAcaAUWm1mzu+9196fLeX6RSlFQEJks6JSBzwG3uPtIjtfeBfx18uu3JR/j7r8F3g98Gug2s2+Z2akVOL9I2SkoiEz2GHCOma0C2pgYJsrl34HzzWw+cAXJoADg7ne5+58DLwMcWF+B84uUneYpiKQxs1bgMHAQeI+7f3ea138XaAJe4u5nJ59bBJwGbGciIGwAYu5+fbnPL1JuulMQSePuw8Avgb0Fdsh3AX9B2l0CE/mEzwLPA38A5gKfgIkgYmafKOP5RcpKdwoiacysBfgtcFUyadxQ5xfRnYLIZJ8CtofYIYd9fmlwCgoigJmdY2YvAK8H/rbRzi+SouEjEREJ6E5BREQCTWE34Hhceuml/sADD4TdDBGRWmO5vlHTdwrPP/982E0QEakrNR0URESkvBQUREQkoKAgIiIBBQUREQkoKIiISKCmS1KlPiUSTu/gCCNj47Q0xZkzq4VYLGcFnYiUkYKCREoi4ew5NMC779jB/r4h5s9u4yvXLmPRvA4FBpEq0PCRRErv4EgQEAD29w3x7jt20DuozcdEqkFBQSJlZGw8CAgp+/uGGBkbD6lFIo1FQUEipaUpzvzZbZOemz+7jZameEgtEmksCgoSKXNmtfCVa5cFgSGVU5gzqyXklok0BiWaJVJiMWPRvA62rlmh6iORECgoSOTEYkZXR2vYzRBpSBo+EhGRgIKCiIgEFBRERCSgoCAiIgEFBRERCSgoiIhIIJSSVDPbCwwA48CYuy8zs5OAbwOnA3uBq9y9L4z2iYg0qjDvFC5w99e4+7Lk448B29x9IbAt+VhERKooSsNHlwGbk19vBi4PsS0iIg0prKDgwINmttPMbkw+N8/dDwIk/56b7Y1mdqOZ7TCzHT09PVVqrohIYwhrmYsV7v6cmc0FHjKz3xT6RnffBGwCWLZsmVeqgSIijSiUOwV3fy75dzewFVgOHDKzUwCSf3eH0TYRkUZW9aBgZrPMrCP1NXAJ8CRwH3Bd8mXXAd+pdttERBpdGMNH84CtZpY6/13u/oCZPQbcbWbvBJ4F3hpC20REGlrVg4K7/w54dZbne4GLqt0eEZFakkg4vYMjFdtvRPspiIjUiETC2XNogHffsYP9fUPBzoSL5nWULTBEaZ6CiIjk0Ts4EgQEgP19Q7z7jh30Do6U7RwKCiIiNWJkbDwICCn7+4YYGRsv2zkUFEREakRLU5z5s9smPTd/dhstTfGynUNBQUSkRsyZ1cJXrl0WBIZUTmHOrJaynUOJZhGR
kBVaURSLGYvmdbB1zQpVH4mI1KNiK4piMaOro7Vi7dHwkYhIiKpRUVQMBQURkRBVo6KoGAoKIiIhqkZFUTEUFESkbBIJp2dgmAN9R+kZGCaR0Or206lGRVExlGgWkbKoxhIM9agaFUVFtSeUs4pI3YlawrSWpCqKTps9k66O1lCDqIKCiJRF1BKmUhoFBREpi6glTKU0CgoiUhZRS5hKaZRoFpGyiFrCVEqjoCAiZVPpJRik8jR8JCIiAQUFEREJKCiIiEhAQUFERAIKCiIiElBQEBGRgIKCiIgEFBRERCSgoCAiIgEFBRERCWiZCxGRAiUSTu/gSLC20+y2ZvqGRutqrScFBRGRAmTuLHfJ4rm896KXs/obO+tqpzkNH4mIFCBzZ7lVSxcEAQHqZ6c5BQURkQJk7izX2dZclzvNKSiIiBQgc2e5/qHRutxpTkFBRKQAmTvLbdm5jw1XL627nebM3cNuQ8mWLVvmO3bsCLsZItIg6qj6KGcjQ6s+MrM4sAM44O4rzewM4FvAScDjwDXuXtsZGxGpK9l2lqu3nebCHD56H/DrtMfrgc+7+0KgD3hnKK0SEWlgoQQFM5sPvAn4l+RjAy4E7km+ZDNweRhtExFpZGHdKXwB+CiQSD6eA/S7+1jy8X7gtGxvNLMbzWyHme3o6empfEtFRBpI1YOCma0Eut19Z/rTWV6aNQPu7pvcfZm7L+vq6qpIG0VEoiqRcHoGhjnQd5SegWESifIWC4WRaF4BvMXM3gjMAE5g4s6h08yakncL84HnQmibiEhkZS61UYmlNap+p+DuH3f3+e5+OvBXwA/c/e3Aw8CVyZddB3yn2m0TEYmyzKU2KrG0RpQmr60FPmhmv2Uix/CvIbdHRCRSMpfagPIvrRHqKqnu/gjwSPLr3wHLw2yPiEiUpZbaSA8M5V5aI0p3CiIiFVXpJG2lZS61UYmlNbSfgog0hGokaSstFjMWzetg65oVFVtaQ3cKItIQqpGkrYbUUhunzZ5JV0dr2QOagoKINIRqJGnrgYKCiDSEzP0QoD72Pyg3BQURaQjVSNLWAyWaRaQhVCNJWw8UFESkYWTbD0Em0/CRiIgEdKcgUocyt43UMIkUSkFBpM6Ue5KWAkxj0fCRSJ0p5yStVIC54vbtrFj/MFfcvp09hwZqbnkIKZyCgkidKeckrXqZBSyFU1AQqTPlnKSlWcCNR0FBpM6Uc5KWZgE3HnOv3bHBZcuW+Y4dO8JuhkjkHE9yOP29bS1xDh0erumVRSWrnP94qj4SqUOlTtLKVrl0xw3LuXfNeYyOJVR91AA0fCQigWyJ5Wu/+jMMq9hSzRItulMQkUCpiWXNZagfCgoiNSqzI57d1kzf0Ohxdcyl7AFcDzuayYs0fCRSYyaCwTF2HzwcTCr75NYn+E2Jk8zS9y12nDtuWF5U5ZLmMtQX3SmIZBH2cEiu86euyv/wwjFu+c6TQUe8aukCVn9j55SOeeuaFXkTzrmu8u+7eQVDI4X97JrLUF90pyCSIeylHfKdP3VVPrMlPqkj7mxrLqljznWVP56g4MSy5jLUFwUFkQxhD4fkO3/qqrx/aHRSR5z5GArrmMtxla8dzeqLgoJIhrCHQ7Kdv6u9lZGxccbd+dr1f8a23YdYv2pJ0BFv2bmPL1+9tOiOuRxX+ek7mm1fewFb16xQkrmGKacgkqGUCpxKnv+qpfNZff6ZHHzhGL2DI2zZuY93rDiDrY8fYN1lZ/GyOTNpjsc4uaO16K0mU1f5mTmFYq/ytaNZ/dAyFyIZwi6xTD9/V3srf3/ZK1nzzceDtqxftYTNP36G//nmV2JAW0uczrbSE+FjYwm6jwwzNp6gKR5jbnsrTU0aRKhzOf+zKCiIZBGl6qO/3PTolLuWW1Yu5qxTT+C02TOP+zyaY9CQcv7j6nJAJIvUcEhYSzukD8dky2/MmdVSluGssJPqEj0KClKz0idd9QwM1+VuYLkSwXM7WstS3RN2Ul2iR0FB
alLYcwmqJVu558ZrlnLqiW1luXvRHAPJpJyC1KSegWGuuH37lLH26Wbw1qJK5jeUU2hY2k9B6ksjDXtUstwzfY6BVjgVUFCQGhX2XIJ6ojkGkq6onIKZfd/MXl2pxogUSksriFRG3pyCmS0GPuHuVycfnwPcBvw++fzBok9oNgP4IdDKxJ3KPe7+KTM7A/gWcBLwOHCNu+eti1NOobGFPZdA9G9Qw0rOKWwDzk09cPfHgQvNbBXwgJndC3zO3YdyHSCLYeBCdz9iZs3Aj8zsu8AHgc+7+7fMbAPwTuDLRRxXGoyGPcKlJHU4Kh2Ipxs+ugT4TPoTZmbAHiY67L8FnjKzawo9oU84knzYnPzjwIXAPcnnNwOXF3pMEak+TXyrvmqUYucNCu7+S3d/e+qxmf0IOAB8HjgNuB44H1huZpsKPamZxc3s50A38BDwNNDv7mPJl+xPHj/be280sx1mtqOnp6fQU4pImTVSBVhUVCMQF1t9tBr4lU9NRPytmf260IO4+zjwGjPrBLYCr8j2shzv3QRsgomcQqHnFJHyUgVY9VUjEBdVfeTuT2YJCClvKvbk7t4PPAK8Dug0s1SQmg88V+zxRKR6VAFWfdWYgV71Gc1m1gWMunu/mbUBDwLrgeuALWmJ5ifc/fZ8x1L1kUi4VH1UXWVM7kdn6WwzW8JEIjnOxJ3K3e7+D2b2J7xYkroLuNrdh/MdS0FBqkEdn0RJmf4/RmeZC3d/Ajg7y/O/A5ZXuz0i+VSi7FJBRo5HpUuxtUqqSB7lrvZolNVdpXYpKIjkUe5qD9X2S9QpKIjkUe5qD9X2S9QpKIjkUe6yS21qI1GnTXakZoSVoC3neXMlrhd2tdM3NKrks1RLdEpSy0lBoXGUswoo7OqfzPPPbmvmqZ4jWlhOqinnfywNH0lNKFeCNgrVP6mSwtNmz6Sro5W+oVElnyUyFBSkJpQjQZtIOH84fCxyHbCSzxIl2o5TasLxLr6WukMYHB4rqAMuZYgp/T1mRtwgFotN+14tLCdRojsFqbpEwukZGOZA31F6BoYLGroptAoo17FTw0+9gyM5q3/S33ugf4i9zw+yv2+IJw+8wN7ewbztzByWumrjT/htzyCf3PrEtMNTWlhOokSJZqmq40kYT3f1nu/YB18YYsX6hzl7QScffsMi1m55Ykr1T2ay99Yrl/C5B/bQc2SYW69cwqKTOzhpVvblBXoGhrni9u1TrvZvWbmYdffvZuuaFXmXJgg7+S0NR4lmiYZCE8bZrvgzE7SZnWa+Y6eGaHbt6+e27+3hlpWLuWf1udx907ksmteRNdn7kXueYPX5ZwZfD43kHuPPlRfobGsuKD8w3c8mUi0KClJVhSRVS60Qynfs9CGaXfv6WXf/bma1NnHyCTOIxSxvp576ejzP6XNNSusfGlV+QGqKgoJUVSEzekstP8137FjMWDSvg61rVrB97QVsXbNi0pBVvk499fWM5ty/LtnyAutXLWHLzn3KD0hNUU5BqqqQnMKBvqOsWP/wlPduX3sBp82eeVzHLua96TmFQo5TavWRSAg0o1miY7qkaipp29Xeyurzz6SzrZmjI+O8esGJORO92Y7d3BSjKWYMjRSWvD2e94rUGAUFqR2JhLO3d5BDh4/xkXvSqoSuWca8E1sL6qgrsTmOSB1R9ZHUjljMaJ/RFAQESOYV7tzBL/a9UFDyWfsWiJRGQUEiaXQskbUaaGZLPPg6Xycf1aUjSpm4J1JNWuZCilaNiVa5ln5IVQNB/k4+iktHaEhLaoHuFKQolV5lNHUlPTI2zl3vei2XLJ4LEFQDbXjkac5e0MnGa5Zyz+pzMbOs547i0hEa0pJaoDuFOlPpq/hcHVuuZRyKaU+2K+mN1yxl3WVnYTHjyLExujpauO68M6YsU5F5tZ0+LyEqS0dEdUhLJJ3uFOpINfYKKKZjK7Y92QLOTXfuJBaLMbdjBqfPmcWn33JWEBBSr3n3HTt4fnC46GUxqk1b
cUotUFCoI9UYniimYyu2PdMFnFjMcPesrzk6PJ4z+EQluRvFIS2RTBo+qiPVGJ5IdWyZydJsHVux7Wluik2bHM6VQH4mucx16hypIa05s1oqltwtdqguikNaIpkUFGpc5tIKlyyey4O7u4Pvl3t4Il/HltlJFtLJp4yNJRgeTbD5huU823uUL257KlheIj3gZAtKG69eyt/9x5OTjpcKPsXmQApVaiVRakhLJKoUFGpYto5pw9VLAXhwd3fFhieydWzZ2nLHDcsLuqtIJJw93QPcdOfO4HVffvs5HBtNMO+E1mkTyPEY9BwZ5uwFnXzokpdz8okziMeM5niMRCL7fIfjvXuqVLARCZuWuahh2TZ2uWTxXD715lcylnCa4zHmtrfS1PRi6qgS1UmpvY+v2viTKXcF9928gvEEec+X6+f4yBv+lNamGE3xWN7F5VLLYvQeGeYDd/9i0oJ2XR2tfO6B30y5ezrezrvURftEIiLnL73uFGpY5pj92Qs6ue68M/jLTY9mHdKoxOSp6fY+HhoZn3Zl06HRsaw/xzu+/ljQzvWrlrD5x8/wgYsXZS0/bZ/RxLVf/dmUTXLWXXYWf/emxew+ODBtDqQYUZwcJ1IOqj6qYZmVQKvPPzNruWaq2qeYaqBCK3YK2fs413G7B46xt3eQp7sHp/051m55gmvPPT1ne/MtixGPWc59FEqlSiKpV7pTqGGZSdc5s1ryjp8XWg1UzB1F6pgbHnma9auWTJlUFo9NDLW0NMWZ3dacdR/k7/7y4KT35vo5Tulso6u9NWs+INeV+9GRiQR8ucf5VUkk9UpBoYZldkxmlndIo9Ahj+mSqJkVT5l7H8+Z1cKpnW0cGx3nLV/aHgSAu9712qz7IN965RIA7rxhOePutDVnb+ezvUd570ULaWuJB0thpDrjObNa2HjN0knJ6luvXEJbS5x4hfppVRJJPdLwUY1Ln7V78gkz8g5pFDrkke+OInOW8qfve5INVy+dsvfxjObYlDH+7oHhKcftam+lvXVimewL/s9/cv3XHmNodJyN1yydsrXlF7c9xcvntXPo8PCUiWoAp5w4g9ve+mq+/8HXc+cNywG4/eHfEovpv7lIoXSnUEemG9IodMgj3x1F5l1Eqqrn7pvOxd2DYx58YWhKAEjlHdKff+9FC3nPNx+fFDyu/9pj3Pue81h32VnMbInTPzTKbd+b2BbTzHLexcyZ1cKJbc1c/7XHJg1haZxfpHAKCnVmuiGNQoY88s1aztbZP7i7m0+92SdVGWULLFt27psyxHPGS2ZlvSsZHU9w8okzprQh1zIXI2PjGucXKYOqBwUzWwDcAZwMJIBN7v5PZnYS8G3gdGAvcJW791W7fY0sPVcwp72F+25eMWXry0LzEtkCywcuXsTCrvZJnbbjOY+3aF4b9645j2OjCeIGbS1xxhK5Xw8a5xc5XlWfvGZmpwCnuPvjZtYB7AQuB64H/ujunzWzjwGz3X1tvmM1+uS1ciq04ijf64BJCeiWuDEy7pOGlbJNPMt3vGyzpIfHEgVVRlVjMyCRGpXzFyH0Gc1m9h3gS8k/57v7wWTgeMTdF+V7r4JC+WSbVZxr5m+2zhamduD5JpylH2tv7yC/7z3KzJY4R0fGedmcmZw+Zxa9gyNZ21TILGntciaSVzRnNJvZ6cDZwE+Bee5+ECAZGObmeM+NwI0AL33pS6vT0AZQzIqmmUM0qWUuBofHuGXlYrbtPsRFi+fR2hTjI2/4U2793m/4xytexdyOGVOO1Ts4MqlKCV4MRrnaNN0s6dRxtTaRSPFCCwpm1g5sAd7v7ofNCrt6c/dNwCaYuFOoXAsbS6nLNmS7Ir/97efwpR88FSzKt37VEjzhdA8cY3QsMenqPl8wytWm5qbYlHkKuSbVZTuuiOQWSgG3mTUzERC+6e73Jp8+lBw2SuUdunO9X8qv1GUbsl2Rr/nm46xauiB4vPnHz9BzZIT/cfuPp2yCk2/TnlxtOnJs
bNrd3LTLmUhpwkg0G7CZiaTy+9OevxXoTUs0n+TuH813LOUUyquUxGyu1UK/fePr+MtNjwJM7LN8/+4pV/zfvvF1tLXEOXR4OOfYf2ab4jGCWdJnL+hk9flnBjOoTz5hRtBe5RRE8opUTmEFcA3wSzP7efK5TwCfBe42s3cCzwJvDaFtkVbpappSyjlzDfH0D40Gj3OtZbS/b4gP/fsvuOOG5dy75rwpQ0vZ2nSg72gQED78hkVT1lpKdfqasyBSmqoPH7n7j9zd3H2Ju78m+ef/uXuvu1/k7guTf/+x2m2LsszlJXINm1Rb54wmNl49eUmKDVcvZcvOfZy9oJOvXf9nzDthRtahnP6hUfb3DXHtV3+GYZw2eyZdHa15O+5UEJpuRViYvATIdMcVkQma0VwjolhNMzaWYE/3Eb647b+DhfC6Olo5paOV/73qVRzsH2b1N3bS1d7KrVcu4SP3vHhVv37VEm773p7gZyk0AZzKM+Tav0GJZJHjo6BQI6JYTdN95MVOHyAeM0bHneePjgDGF7f9dzBM9LkH9rDusrM4s2sWT/cMctv39rBrXz9QXAI4NSz0h8PHtMmNSAVo+cgaEcVqmtHxBF3trXz4DYtYd/9urrj9x1z/tZ/xzPNH+fR9T3LdeWdw9oJOAHbt6+cdX3+MeMw4+cQZ9BwZDn6GYheti8Vs2hVhRaQ0oc9oPh6NVH0UxWqa5/qHONA3xJHhsWA10w2PPE3PkWFuWbmYdffv5paVi7npzp3AixVHzU0xmmI2ZV2lYmkZC5GSRar6SEoQpWqaVGfcEjfaWuJ84O6fT8kVdLY1B7uoAcGmNzfftYueI8NlCWha/E6k/DR8VEOiUE2TXgW189l+Vn9j55S9lN970UL6h0aZP7uNzpktfP+Dr2fdZWfxuQcm8gj59oYWkXApKERU+gb3PQPDjI0lJj1OL0XNfG2h3ytFehVU6m4g3f6+IV46ZyZbdu6bqDj691/Qe2SEd3z9sSCxnHqdKoVEokfDRxGUmT+4ZPFc3nvRy4Or8umWly7ke9lmCxcyHJVeBZW6G8isAGprivGpN7+Sm+/axa59/Tlfp0ohkejRnUIEZc5JWLV0wZRhmtTwS675C9N9b2wswf6+o/y+d5AnnzvMJ7c+UdBkuPQqqA2PPM36VUumVACdfOJEh5+qMMr1OlUKiUSP7hQiKHNOQq5hmqGRMeIxo6u9ddL304dmsr0vkUiwp3tg0raY61ct4fMP7eEzVyzJm7xN31Ft175+Nv/4Ge5612uJJ3dlS91tFPo6EYkWBYUIylxPKNfwy6//MMC6+3dz65VLgiRu6nupoZls7xt3goAALyaIb1m5eNpx/kKroKJULSUihdPwUQRlLhm9Zec+vpyxvtD6VUvY8MjT7O8b4iP3TFT8pL638eqlxGMwu6056wQvd896BzFnVktB4/yFVkFFoVpKRIqjyWsRlZ4EHk8433x0L+ecPoeFc9t5qvsIGx55elI1zw8/ej6JBDzz/CBf3PZUMBdgYVc7fUOjk67Wc21zede7Xsv82TPLNpEM0OQykWjS5LVak7rKnrR38n/tzbk3QVMsxlVf+cmk53MtmJc+3p/KKWy8ZimnnthWckDIVuXU2hQLttqMwgxsEZmeho8iJNucgsykc65KnrhlTyrn2mM5Nd6/fe0FbF2zglecfAJNTaX9d8hV5fT73qN5l7YWkejRnUIFFTMPINfV9rwTWicli1OVPHffdC7uPmlIqJi5AIUuEVHIz5BrBdeZLfEpz2nCmki0KShUSLZO/o4bltM+oynrDmO5rrbvXXPelKGeD1y8aNLWk5B9SOh45wIUughfrt3Xjo6MB1tmdrY1c3RknLYWTVgTiTIlmitkUi4AOHtBJx+9dNGkjWbSO9hcex1vX3sBp5zYVtAdR7lXDc38GWCis8/MU+QKHu2t8aA6SnkFkUhRornaModUVp9/ZtA5wtSd03Jdbbe1xAvu6Mu9amihG/vk
mpPw/OBw3p9ZRKJHieYKydwUJ9es5FQHmzk3ITXcdOjwcGj7MhezsU+2OQmjY4nI7RYnIvkpKFRIZid/dGQ8bwebWRF0380rmNEcz7l2USVkVj/lmvxWaJ4iirvFiUh+yilUUPoYf1tLnEOHhwvaOS01Rj84PMaVG34y5bjb117AabNn5j1fsTmFXHmBbJPfjveYyimIhC7nL6CCQhUV2mmnErypLS2nS/Smjl1MtVOucxZyrkr8zCJSVUo0R0GhieBUgjc1UW3tliemLTPNLGntam/l0OFjXPvVwip/Ck0qF0tbZorUFgWFElT66jc1Fr9rXz+3fW8Pt6xczJxZLZza2TZlfkJKsdVOuc6pjXBEGpsSzUVK36O4UhVB6UnqXfv6WXf/bma1NuUMCFB8tVO+c4I2whFpVLpTKFKumcflrL0vZS+CzBnNqWqnYpa9WNjVzt03ncvYeIKmeIy57VruWqTRKCgUqRJj77mGo4oJMpmBpK0lXtSyF4mE81TPEVUKiTQ4BYUilXvsvZiyzelyGZmBpLOthXvXnMex0QRxI++6Q9W4AxKR6FNOoUjlHnvP1RlnTlArNZfRe2SEt33lUVasf5i3fCn3eypVfSQitUVBoUjZ9iI4niGWQjvjQoNHqe/R7GMRAQWFkpRz7+FCO+NSruSLeY+qj0QElFMIXaH7IJSSyyjmPaVUPIlI/Wm4ZS6iuOxCIW0qZR0hrT0kIjlEa+0jM/sqsBLodvezks+dBHwbOB3YC1zl7n35jlNsUKj1TrKUgBbFICgiocvZCYSVU/g6cGnGcx8Dtrn7QmBb8nFZlZKsjZJSchnlzH+ISP0LJSi4+w+BP2Y8fRmwOfn1ZuDycp+3lGRt5h4D1drgplbaIyL1JUqJ5nnufhDA3Q+a2dxyn6DYZO3YWILnXhiie2CY3sERtuzcxwcuXhTacFOtD3+JSPTVXEmqmd1oZjvMbEdPT09R7y2m7DKRcPZ0D/C2f/kpV274Cevu3811553B5x/aE9pwUynDX7qzEJFiROlO4ZCZnZK8SzgF6M72InffBGyCiURzMScopuyyd3CEm+7cOakDXrvlCW5ZuTi0Wb7FDn/pzkJEihWlO4X7gOuSX18HfKcSJyk08ZqrA+1b2QgAAATxSURBVJ4zqyW0Wb7Fzjqu9cS6iFRfKEHBzP4N+AmwyMz2m9k7gc8CF5vZU8DFycehydUBz+1oDW2Wb7GzjrWekYgUK5ThI3f/6xzfuqiqDckj20zjjdcs5dQT20Ibeil21rF2UxORYjXcjOZi1PpkMeUURCSHaM1oLpdKB4ViRbETjlKQEpHIiNyM5roUxcSuZjSLSDEUFMpIiV0RqXUKCmWkjWpEpNYpKJSRNqoRkVoXpRnNNU8b1YhIrVNQKLNUYldEpBZp+EhERAIKCiIiElBQEBGRgIKCiIgEFBRERCRQ02sfmVkP8Puw25H0EuD5sBsRUfpsctNnk58+n9yO57N53t0vzfaNmg4KUWJmO9x9WdjtiCJ9Nrnps8lPn09ulfpsNHwkIiIBBQUREQkoKJTPprAbEGH6bHLTZ5OfPp/cKvLZKKcgIiIB3SmIiEhAQUFERAIKCiUws6+aWbeZPZn23Elm9pCZPZX8e3aYbQyLmS0ws4fN7Ndm9isze1/y+Yb/fMxshpn9zMx+kfxs/j75/Blm9tPkZ/NtM2vYDTjMLG5mu8zs/uRjfTaAme01s1+a2c/NbEfyuYr8TikolObrQObEj48B29x9IbAt+bgRjQEfcvdXAK8D/sbMFqPPB2AYuNDdXw28BrjUzF4HrAc+n/xs+oB3htjGsL0P+HXaY302L7rA3V+TNjehIr9TCgolcPcfAn/MePoyYHPy683A5VVtVES4+0F3fzz59QATv+Cnoc8Hn3Ak+bA5+ceBC4F7ks835GcDYGbzgTcB/5J8bOizyaciv1MKCuUzz90PwkTHCMwNuT2hM7PTgbOBn6LPBwiGR34OdAMPAU8D
/e4+lnzJfiaCaCP6AvBRIJF8PAd9NikOPGhmO83sxuRzFfmd0s5rUhFm1g5sAd7v7ocnLvrE3ceB15hZJ7AVeEW2l1W3VeEzs5VAt7vvNLPzU09neWnDfTZJK9z9OTObCzxkZr+p1Il0p1A+h8zsFIDk390htyc0ZtbMRED4prvfm3xan08ad+8HHmEi79JpZqkLtPnAc2G1K0QrgLeY2V7gW0wMG30BfTYAuPtzyb+7mbiYWE6FfqcUFMrnPuC65NfXAd8JsS2hSY4D/yvwa3f/v2nfavjPx8y6kncImFkb8BdM5FweBq5MvqwhPxt3/7i7z3f304G/An7g7m9Hnw1mNsvMOlJfA5cAT1Kh3ynNaC6Bmf0bcD4TS9ceAj4F/AdwN/BS4Fngre6emYyue2b258B/Ab/kxbHhTzCRV2joz8fMljCREIwzcUF2t7v/g5n9CRNXxycBu4Cr3X04vJaGKzl89GF3X6nPBpKfwdbkwybgLnf/jJnNoQK/UwoKIiIS0PCRiIgEFBRERCSgoCAiIgEFBRERCSgoiIhIQEFBREQCCgoiIhJQUBApEzN7lZltT3t8jpn9IMw2iRRLk9dEysTMYkyszXOau4+b2cNM7C3xeMhNEymYVkkVKRN3T5jZr4BXmtlC4FkFBKk1Cgoi5fUoEyt+rmHq7nwikaegIFJejzKxXes/u/uBkNsiUjTlFETKKDls9J/AQncfDLs9IsVS9ZFIeb0P+LgCgtQqBQWRMjCzM5NbJLa5++Zp3yASURo+EhGRgO4UREQkoKAgIiIBBQUREQkoKIiISEBBQUREAgoKIiISUFAQEZHA/wdyM0xR+HrWDgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "ffnn = FeedForwardNeuralNetwork()\n",
    "ffnn.fit(X_boston_train, y_boston_train, n_hidden = 8)\n",
    "y_boston_test_hat = ffnn.predict(X_boston_test)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "sns.scatterplot(x = y_boston_test, y = y_boston_test_hat[0], ax = ax)\n",
    "ax.set(xlabel = r'$y$', ylabel = r'$\\hat{y}$', title = r'$y$ vs. $\\hat{y}$')\n",
    "sns.despine()"
   ]
  },
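  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classification example below passes `loss = 'log'` and `f2 = 'sigmoid'`. With a sigmoid output and binary cross-entropy, the chain rule collapses neatly: for $\\hat{y} = \\sigma(h)$ and $\\mathcal{L} = -[y \\log \\hat{y} + (1 - y) \\log(1 - \\hat{y})]$, we get $\\partial \\mathcal{L} / \\partial h = \\hat{y} - y$. The sketch below checks this numerically on made-up values. It assumes the loss is summed over observations (the class's exact definition is not shown in this section), and `toy_sigmoid` and `toy_log_loss` are hypothetical helpers written only for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def toy_sigmoid(h):\n",
    "    return 1 / (1 + np.exp(-h))\n",
    "\n",
    "def toy_log_loss(h, y):\n",
    "    # Assumed form: binary cross-entropy summed over observations, as a function of h\n",
    "    p = toy_sigmoid(h)\n",
    "    return -np.sum(y * np.log(p) + (1 - y) * np.log(1 - p))\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "H2 = rng.normal(size = (1, 6))                          # output pre-activations, (1 x N)\n",
    "y = rng.integers(0, 2, size = (1, 6)).astype(float)    # binary labels, (1 x N)\n",
    "\n",
    "# Closed form: dL/dH2 = sigmoid(H2) - y\n",
    "analytic = toy_sigmoid(H2) - y\n",
    "\n",
    "# Central finite differences, one observation at a time\n",
    "eps = 1e-6\n",
    "numeric = np.zeros_like(H2)\n",
    "for n in range(H2.shape[1]):\n",
    "    H2_plus, H2_minus = H2.copy(), H2.copy()\n",
    "    H2_plus[0, n] += eps\n",
    "    H2_minus[0, n] -= eps\n",
    "    numeric[0, n] = (toy_log_loss(H2_plus, y) - toy_log_loss(H2_minus, y)) / (2 * eps)\n",
    "\n",
    "print(np.allclose(analytic, numeric, atol = 1e-6))"
   ]
  },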
  {
   "cell_type": "code",
   "execution_count": 638,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9929577464788732"
      ]
     },
     "execution_count": 638,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ffnn = FeedForwardNeuralNetwork()\n",
    "ffnn.fit(X_cancer_train, y_cancer_train, n_hidden = 8,\n",
    "         loss = 'log', f2 = 'sigmoid', seed = 123, lr = 1e-4)\n",
    "y_cancer_test_hat = ffnn.predict(X_cancer_test)\n",
    "np.mean(y_cancer_test_hat.round() == y_cancer_test)\n"
   ]
  }
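  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The accuracy computation above relies on broadcasting: `predict()` returns an array of shape `(1, N)` (one row per output neuron), while `y_cancer_test` is presumably a flat array of shape `(N,)`, so the comparison broadcasts to `(1, N)` and `np.mean` gives the fraction of correct predictions. The sketch below uses small made-up arrays to show why this works, how it silently goes wrong if the labels are stored as an `(N, 1)` column (the comparison becomes `(N, N)`), and how flattening both sides with `ravel()` makes the comparison element-wise regardless of layout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Made-up predicted probabilities in the (1 x N) layout returned by predict()\n",
    "y_hat = np.array([[0.9, 0.2, 0.6, 0.4]])\n",
    "y = np.array([1, 0, 0, 0])        # labels as a flat (N,) array\n",
    "y_col = y.reshape(-1, 1)          # the same labels stored as an (N x 1) column\n",
    "\n",
    "# (1 x N) compared with (N,) broadcasts to (1 x N): one entry per observation\n",
    "matches = y_hat.round() == y\n",
    "print(matches.shape, np.mean(matches))\n",
    "\n",
    "# (1 x N) compared with (N x 1) broadcasts to (N x N): every prediction vs every label\n",
    "bad = y_hat.round() == y_col\n",
    "print(bad.shape, np.mean(bad))\n",
    "\n",
    "# Flattening both sides makes the comparison element-wise regardless of layout\n",
    "accuracy = np.mean(y_hat.ravel().round() == y_col.ravel())\n",
    "print(accuracy)"
   ]
  }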
 ],
 "metadata": {
  "celltoolbar": "Edit Metadata",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
